import torch
import time
import model
import torch.optim
import torch.nn
from dataset import ParkingSlotDataset
from torch.utils.data import DataLoader
import myloss
import os
import val
import ps_evaluate
from myconfig import *
from thop import profile

# BUGFIX: was "CUDA_VISIBLE_DEVICS" (typo, missing 'E'), which CUDA ignores
# entirely — the intended GPU masking never took effect.
# NOTE(review): for this to work it must run before any CUDA context is
# created; torch is already imported above, which is fine as long as no
# CUDA call has happened yet — confirm against the rest of the project.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
# Use the file-system sharing strategy to avoid "too many open files"
# errors when many DataLoader workers pass tensors between processes.
torch.multiprocessing.set_sharing_strategy('file_system')

# Example local paths kept for reference; the live values (weight_path,
# dataset_path, ...) come from `from myconfig import *`.
# weight_path = "F:\\study\\code\\mycode_3-22_with_relu\\weights\\"
# dataset_path = "G:/0A/output/data-input/output/"

batch_size = 16   # samples per training batch
num_workers = 16  # DataLoader worker processes
f_epoch = 0       # epoch-number offset applied when naming checkpoints

def train() -> None:
    """Run the full training loop for the PSNet parking-slot model.

    Loads pretrained weights, trains for 200 epochs over the dataset at
    ``dataset_path``, and saves the model state_dict once per epoch.
    Relies on names star-imported from ``myconfig`` (``weight_path``,
    ``weight_name``, ``dataset_path``) and on the project modules
    ``model``, ``myloss``, and ``dataset``.
    """
    device = torch.device('cuda:0')
    torch.set_grad_enabled(True)
    net = model.PSNet()
    # Only GPU index 0 is actually used, even if more devices are visible.
    device_ids = [0]
    net = net.cuda()
    net = torch.nn.DataParallel(net, device_ids)
    
    # Resume from a pretrained checkpoint; paths come from myconfig.
    net.module.load_state_dict(torch.load(weight_path+weight_name, map_location='cuda:0'))
    
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-5)
    # Optional: also restore optimizer state when resuming.
    # optimizer.load_state_dict(torch.load(weight_path+op_name, map_location='cuda:0'))
    data = ParkingSlotDataset(dataset_path)
    # collate_fn transposes a list of per-sample tuples into per-field
    # tuples, leaving variable-length fields (e.g. mark_points) as plain
    # Python sequences instead of stacking them into tensors.
    data_loader = DataLoader(data,
                             batch_size=batch_size, shuffle=True,
                             num_workers=num_workers,
                             collate_fn=lambda x: list(zip(*x)))
    # Batches per epoch, floor-divided — a trailing partial batch is not
    # counted in this progress estimate.
    num_batch = int(data_loader.batch_sampler.sampler.num_samples / data_loader.batch_size)

    Loss = myloss.C_Loss()
    aver_loss = 0  # NOTE(review): never used below — candidate for removal.

    for epoch_idx in range(200):
        net.train()
        all_loss_pos = 0
        all_loss_angle = 0
        aver_loss_pos = 0
        aver_loss_angle = 0
        aver_time = 0
        for iter_idx, (images, target_pos, target_angle, mark_points) in enumerate(data_loader):
            start = time.time()
            optimizer.zero_grad()
            # Stack the per-sample tensors produced by the collate_fn into
            # batched tensors and move them to the GPU.
            images = torch.stack(images).to(device)
            target_pos = torch.stack(target_pos)
            target_pos = target_pos.to(device)
            target_angle = torch.stack(target_angle).to(device)
            output_pos, output_angle = net(images)
            # C_Loss returns a scalar position loss, a (presumably
            # non-scalar) angle-loss tensor, and a matching `gradient`
            # weighting tensor — TODO confirm against myloss.C_Loss.
            loss_pos, loss_angle, gradient = Loss(output_pos, output_angle, target_pos, target_angle, mark_points,device)
            # Two backward passes over the same graph: retain_graph=True is
            # required so the second backward can reuse it. The second call
            # passes `gradient` as the vector in a vector-Jacobian product,
            # i.e. a per-element weighting of the angle loss. Order matters.
            loss_pos.backward(retain_graph=True)
            loss_angle.backward(gradient)
            optimizer.step()
            all_loss_pos += loss_pos.item()
            # Scalarize the weighted angle loss the same way backward() did,
            # purely for logging.
            loss_angle_item = torch.sum(torch.mul(loss_angle, gradient)).item()
            all_loss_angle += loss_angle_item
            aver_loss_pos = all_loss_pos / (iter_idx + 1)
            aver_loss_angle = all_loss_angle / (iter_idx + 1)
            # Optional mid-epoch checkpointing, currently disabled.
            # if iter_idx % (num_batch//3) == 0 and iter_idx > 100:
            #     torch.save(net.module.state_dict(), weight_path+"weight{}_{}_{:.7f}_{:.7f}.pth".
            #                format(epoch_idx + f_epoch, iter_idx // (num_batch//3), aver_loss_pos, aver_loss_angle))
            end = time.time()
            # NOTE(review): this is an exponential moving average with
            # weight 1/2, not a true mean — early iterations are heavily
            # discounted. Used only for the remaining-time estimate below.
            aver_time = (aver_time+end-start)/2
            # "剩余时间" = estimated remaining time, printed in hours.
            print("loss_pos:", round(loss_pos.item(), 7),
                  "loss_angle_item:", round(loss_angle_item, 7),
                  " iter:{}/{}".format(iter_idx,
                                       num_batch),
                  "epoch_idx:{}".format(epoch_idx+f_epoch),
                  " aver_loss:{:.7f}   {:.7f}".format(aver_loss_pos, aver_loss_angle),
                  " time:", round((end - start), 7),
                  "剩余时间：", round((num_batch-iter_idx)*aver_time/3600, 7))
        # Validation is currently disabled; checkpoint names carry zeros
        # in the val-loss slots.
        # val_loss = val.val(net, batch_size=batch_size*5, num_workers=num_workers, val_path=val_path, device=device)
        val_loss = [0, 0]
        # Save the underlying module's weights (not the DataParallel
        # wrapper) so the checkpoint loads without the 'module.' prefix.
        torch.save(net.module.state_dict(), weight_path+"weight{}_{:.7f}_{:.7f}_end_{:.7f}_{:.7f}.pth".
                   format(epoch_idx + f_epoch, aver_loss_pos, aver_loss_angle, val_loss[0], val_loss[1]))
        # torch.save(optimizer.state_dict(), weight_path + "op{}_{:.7f}_{:.7f}_end_{:.7f}_{:.7f}.pth".
        #            format(epoch_idx + f_epoch, aver_loss_pos, aver_loss_angle, val_loss[0], val_loss[1]))


if __name__ == '__main__':
    # Script entry point: start training when executed directly.
    train()
